#!/usr/bin/env python
# vim: tabstop=4 shiftwidth=4 softtabstop=4

# Copyright 2014 Intel
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import re
import os
import json
import sys
from vsm.manifest.parser import ManifestParser
from vsm.manifest.wsgi_client import WSGIClient
from vsm import utils
from vsm import flags
from vsm import ipcalc
from vsm import db
from vsm import context

# Shared VSM flags object; this module reads FLAGS.etc_hosts as the path
# handed to utils.write_file_as_root() when the hosts file is written.
FLAGS = flags.FLAGS

def error(msg):
    """Print a manifest error report to stdout and exit with status 1.

    :param msg: a single message, or a list of messages printed one per line.
    """
    separator = '------------------------------------'
    print('---------------ERROR----------------')
    print('Errors below is caused by your manifest file')
    print('Please check your manifest file!')

    print(separator)
    # Treat a lone message as a one-element list so both cases share one loop.
    details = msg if isinstance(msg, list) else [msg]
    for detail in details:
        print(detail)
    print(separator)
    sys.exit(1)

def _write_etc_hosts(hosts_content):
    """Write *hosts_content* to the path configured as ``FLAGS.etc_hosts``.

    The write is delegated to ``utils.write_file_as_root`` with mode "w".
    """
    hosts_path = FLAGS.etc_hosts
    utils.write_file_as_root(hosts_path, hosts_content, "w")

class ManifestChecker(object):
    """Validate a server manifest file against data from the VSM controller.

    Constructing an instance runs every check immediately. Any failed check
    is reported through ``error()``, which prints the problem and exits the
    process, so a successfully constructed instance means all checks passed.
    """

    # Substrings accepted in the lines of the [role] section.
    _VALID_ROLES = ('storage', 'monitor', 'mds', 'rgw', 'unspecified')

    def __init__(self, file_path):
        """Parse *file_path*, fetch controller data and run all checks.

        :param file_path: path of the server manifest file to validate.
        """
        self._context = context.get_admin_context()
        self._file_path = file_path
        self._smp = ManifestParser(file_path)
        self._smp._file_path = file_path
        self._info = self._smp.format_to_json(check_manifest_tag=True)
        # Close the manifest promptly instead of leaking the file handle.
        with open(file_path) as manifest:
            self._lines = manifest.readlines()
        self._server_host = self._info['vsm_controller_ip']
        # Initialised before the checks run; __in_lan() fills it in.
        self._lan_list = []
        print(self._info)
        try:
            self._sender = WSGIClient(self._server_host,
                                      self._info['auth_key'])

            self._rec_data = self._sender.index()

            hosts_content = self._rec_data['etc_hosts']
            self._storage_class_list = self._rec_data['storage_class']
            _write_etc_hosts(hosts_content)
            print(json.dumps(self._rec_data, sort_keys=True, indent=2))
        except StandardError as e:
            # error() exits the process, so nothing below runs on failure.
            info_list = ['Can not connect to controller ip = %s' %
                         self._server_host,
                         'Possible reasons: ',
                         '(i) Maybe your auth_keys in server.manifest files are in error! please make sure the auth_keys are consistent '
                         'with the output from agent_token command executed on controller node.',
                         '(ii) Check if you can ping this host',
                         e]
            error(info_list)

        # Section names allowed in the manifest: the controller's storage
        # classes plus the fixed sections.
        self._secs_list = self._storage_class_list + \
            ['vsm_controller_ip', 'role', 'auth_key', 'zone']
        self.check_sections()
        self.check_all_disks()
        self.check_network()
        self.check_role()

    def check_role(self):
        """Abort unless every [role] line contains a known role name."""
        for role_line in self.__lines_after_section('role'):
            if not any(role in role_line for role in self._VALID_ROLES):
                error('Can not define this role %s for VSM.' % role_line)

    def check_all_disks(self):
        """Validate the "<osd> <journal>" device pairs of every storage class.

        Each line of a storage-class section must hold exactly two
        whitespace-separated fields, no device may be listed twice, and
        every device must pass the per-device checks.
        """
        osd_list = []
        js_list = []
        for sec in self._storage_class_list:
            ls = self.__lines_after_section(sec)
            print('sec %s has lines: %s' % (sec, ls))
            for line in ls:
                fields = line.split()
                # must have osd and journal
                if len(fields) != 2:
                    print(fields)
                    error('Device = %s defined error!' % line)

                osd_list.append(fields[0])
                js_list.append(fields[1])

        osd_num = len(osd_list)
        js_num = len(js_list)

        # A device used both as OSD and as journal is also a duplicate.
        if (osd_num + js_num) != len(set(osd_list + js_list)):
            error('There are duplicated osd/journal')

        if osd_num != len(set(osd_list)):
            print(osd_list)
            error('There are duplicated disks for OSD')

        if js_num != len(set(js_list)):
            print(js_list)
            error('There are duplicated disks for Journal!')

        # Defensive: the lists are filled in pairs above.
        if osd_num != js_num:
            error('The OSD devices is not equal to journal devices!')

        self.__check_osd_devices(osd_list)
        self.__check_journal_devices(js_list)

    def __mount_devices(self):
        """Return the result of running ``mount`` via ``utils.execute``."""
        return utils.execute('mount')

    def __check_is_mounted(self, osd):
        """Return True if *osd* equals the third field of a ``mount`` line.

        NOTE(review): in typical ``mount`` output the third field is the
        mount point, not the device, so a device path will presumably never
        match here -- confirm the intent before changing the comparison.
        """
        out = self.__mount_devices()
        for n in out[0].split('\n'):
            parted = n.split()
            if len(parted) < 3:
                continue
            if parted[2] == osd:
                return True

        return False

    def __check_devices(self, dev_list, kind):
        """Abort if a device in *dev_list* is mounted or does not exist.

        :param dev_list: device paths to verify.
        :param kind: device kind named in the error messages
                     ("osd" or "journal").
        """
        print(dev_list)
        for dev in dev_list:
            if self.__check_is_mounted(dev):
                error('You should not mount %s devices.!' % kind)

            if not os.path.exists(dev):
                error('Can not find the %s devices in system path.!' % kind)

    def __check_osd_devices(self, osd_list):
        """Run the per-device checks on OSD devices."""
        self.__check_devices(osd_list, 'osd')

    def __check_journal_devices(self, js_list):
        """Run the per-device checks on journal devices."""
        self.__check_devices(js_list, 'journal')

    def __lines_after_section(self, sec):
        """Return the stripped content lines of the section matching *sec*.

        A header matches when the line contains "[", "]" and *sec* as a
        substring. Comment lines (starting with "#") and lines shorter than
        two characters are skipped; collection stops at the next line that
        contains both brackets.
        """
        lines = []
        find_it = False
        for raw in self._lines:
            line = raw.strip()
            if line.startswith('#'):
                continue

            is_header = '[' in line and ']' in line
            if find_it and is_header:
                # The next section header ends the wanted section.
                print(line)
                break

            if len(line) < 2:
                continue

            if is_header and sec in line:
                find_it = True
                continue

            if find_it:
                lines.append(line)

        return lines

    def __is_address_ok(self, address):
        """Abort unless *address* has four dot-separated parts, then ping it.

        The ping is issued through ``utils.execute``; its result is not
        inspected here.
        """
        if len(address.split('.')) != 4:
            info_list = ['There are some errors in LAN address!',
                         'address should like *.*.*.*/*',
                         'Your address is %s' % address]
            error(info_list)

        utils.execute('ping', '-c', '4', address)

    def __get_section(self, line):
        """Return the section name held in *line*, or None.

        Comment lines and lines containing neither "[" nor "]" yield None.
        Otherwise both bracket characters are replaced by spaces and the
        stripped remainder is returned.

        NOTE(review): a line with only one of the two brackets is still
        treated as a header -- confirm whether both should be required.
        """
        if line.startswith('#'):
            return None

        if '[' not in line and ']' not in line:
            return None

        name = line.strip().replace('[', ' ').replace(']', ' ').strip()
        print('section = %s' % name)
        return name

    def _get_all_section(self):
        """Return the names of manifest sections that have content.

        A section counts when the parser's segment for it has a non-empty
        'single' or 'first' entry.
        """
        sec_list = []
        for raw in self._lines:
            sec = self.__get_section(raw.strip())
            if not sec:
                continue
            content = self._smp._get_segment(sec)
            if len(content['single']) != 0 or len(content['first']) != 0:
                sec_list.append(sec)
        return sec_list

    def check_sections(self):
        """Abort if the manifest has a populated section that is not allowed."""
        sec_list = self._get_all_section()
        print(sec_list)
        print(self._secs_list)
        for sec in sec_list:
            if sec not in self._secs_list:
                error('Can not find section %s in DB!' % sec)

    def _is_in_lan(self, ip, ip_mask):
        """Report whether *ip* belongs to the network *ip_mask*.

        The membership test is currently disabled: every address is
        accepted. The code that used to follow the unconditional return
        (``ip in ipcalc.Network(ip_mask)``) was unreachable and has been
        dropped without changing behaviour.
        """
        return True

    def __in_lan(self, ip):
        """Return True if *ip* is in any network reported by the controller.

        Also records the management, ceph public and cluster network
        strings in ``self._lan_list``. Each string may hold several
        comma-separated networks.
        """
        cluster = self._rec_data['cluster']
        management_network = cluster['management_network']
        ceph_public_network = cluster['ceph_public_network']
        cluster_network = cluster['cluster_network']

        self._lan_list = [management_network,
                          ceph_public_network,
                          cluster_network]
        lan_list = (management_network.split(',') +
                    ceph_public_network.split(',') +
                    cluster_network.split(','))
        return any(self._is_in_lan(ip, lan) for lan in lan_list)

    def check_network(self):
        """Check the controller IP and this host's IPs against the networks.

        Aborts when the controller IP is outside every network, or when
        fewer distinct local addresses (from ``hostname -I``) are in the
        networks than there are distinct network strings. Prints a warning
        when fewer than three distinct addresses match.
        """
        controller_ip = self._info['vsm_controller_ip']
        if not self.__in_lan(controller_ip):
            error('Controller ip is not in lan!')

        out = utils.execute('hostname', '-I')
        ip_list = out[0].strip().split()
        ip_in_lan = []
        for ip in ip_list:
            print(ip_list)
            print(self._lan_list)
            if self.__in_lan(ip):
                self.__is_address_ok(ip)
                ip_in_lan.append(ip)
        ip_in_lan_cnt = len(set(ip_in_lan))

        if ip_in_lan_cnt < len(set(self._lan_list)):
            # error() prints every entry itself, so no separate print loop.
            error(['Controller ip is not in LAN defined in',
                   'controller node /etc/manifest/cluster.manifest',
                   'Your IP = %s' % ip_list,
                   'Controller ip lan = %s' % self._lan_list])
        if ip_in_lan_cnt < 3:
            info_list = [
                'Warning: There are some networks be the same.',
                'Controller ip lan = %s' % self._lan_list]
            for info in info_list:
                print(info)

    def format_to_json(self):
        """Return the parsed manifest as sorted, 2-space-indented JSON."""
        return json.dumps(self._info, sort_keys=True, indent=2)

if __name__ == '__main__':
    # The first command-line argument, when given, overrides the default
    # manifest location.
    manifest_path = sys.argv[1] if len(sys.argv) > 1 \
        else '/etc/manifest/server.manifest'
    if not os.path.exists(manifest_path):
        error('%s is not exists.' % manifest_path)
    # Constructing the checker runs every validation; failures exit inside.
    checker = ManifestChecker(manifest_path)
    banner = '--------------------------------------'
    success = 'Check Success ~~'
    print(banner)
    print(success)
    print(checker.format_to_json())
    print(success)
    print(banner)
